from interaction_instructions import *
from agent_personas import *
import torch
import pickle
import os
from transformers import pipeline
import re
import argparse

def parse_lines_yaml(yaml, tag='state_attribute'):
    """Collect the values of numbered ``<tag>_<n>: value`` entries.

    Args:
        yaml: Text to scan (in this file, the raw response of a model).
        tag: Key prefix to look for. It is interpolated into the regular
            expression as-is (not escaped) and matched case-insensitively.

    Returns:
        A list of strings, one per ``<tag>_<digits>:`` key found, each holding
        the text from the first non-whitespace character after the colon to
        the end of that line. Empty when nothing matches.
    """
    try:
        line_exp = re.findall(fr'{tag}_\d*:\s*\n*(.+)', yaml, re.IGNORECASE)
    except re.error:
        # Only reachable when ``tag`` does not form a valid pattern. Fall back
        # to a plain line scan: keep whatever follows the first ": " on every
        # line that also contains an underscore.
        yaml = yaml.replace(":\n", ": ")
        line_exp = []
        for line in yaml.split('\n'):
            _, separator, value = line.partition(': ')
            if '_' in line and separator:
                line_exp.append(value)

    return line_exp

def parse_yes_no(yaml, x):
    """Extract the explanation and the boolean verdict stored under key ``x``.

    Args:
        yaml: Text to scan (in this file, the raw response of a model).
        x: Name of the key holding the verdict. It is interpolated into the
            regular expression as-is and matched case-insensitively.

    Returns:
        ``(explanation, answer)`` where ``explanation`` is the text following
        the first ``explanation:`` key and ``answer`` is True only when the
        value under ``x`` reads "true".

    Raises:
        IndexError: If either key is missing from ``yaml``.
    """
    explanation = re.findall(r'explanation:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    raw_answer = re.findall(fr'{x}:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    # The captured value runs to the end of the line, so it can carry trailing
    # whitespace, quotes or different casing; normalise before comparing so
    # that e.g. "true " or '"True"' are not silently read as False.
    answer = raw_answer.strip().strip('"\'').lower() == 'true'

    print(f'EXPLANATION: {explanation}')
    print(f'ANSWER: {answer}')
    return explanation, answer

def parse_correct_answer(yaml):
    """Extract the explanation and combined verdict of an answer assessment.

    Args:
        yaml: Text to scan (in this file, the raw response of a model). It is
            expected to contain the keys ``answer_addresses_question``,
            ``answer_has_no_mistakes`` and ``explanation``.

    Returns:
        ``(explanation, is_correct)`` where ``is_correct`` is True only when
        both flags read "true".

    Raises:
        IndexError: If any of the three keys is missing from ``yaml``.
    """
    relevant = re.findall(r'answer_addresses_question:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    logical = re.findall(r'answer_has_no_mistakes:\s*\n*(.+)', yaml, re.IGNORECASE)[0]
    explanation = re.findall(r'explanation:\s*\n*(.+)', yaml, re.IGNORECASE)[0]

    def _is_true(value):
        # Captured values run to the end of the line; tolerate stray
        # whitespace, quotes and casing instead of requiring exactly 'True'.
        return value.strip().strip('"\'').lower() == 'true'

    is_correct = _is_true(relevant) and _is_true(logical)

    print(f'EXPLANATION: {explanation}')
    print(f"ANSWER: {is_correct}; (relevant -> {relevant} and logical -> {logical})")
    return explanation, is_correct

def log(text):
    """Append ``text`` to the run's ``log.txt``, framed by a rule above and below.

    The target directory is derived from the module-level ``FILE_NAME``.
    """
    rule = '\nxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n'
    with open(f'instruct_not_assist/{FILE_NAME}/log.txt', 'a+') as log_file:
        for chunk in (rule, text, rule):
            log_file.write(chunk)

class Verifier():
    """Judge agent that assesses student answers and candidate code.

    Wraps a Llama-3-8B-Instruct text-generation pipeline. The prompt pieces
    (``verifier_persona``, ``problem_statement``, ``correct_code``, ...) are
    read from module-level names at call time.
    """

    def __init__(self):
        """Build the default system prompt and load the generation pipeline."""
        self.init_prompt = verifier_persona + problem_statement + correct_code
        self.model = pipeline("text-generation",
                              model="meta-llama/Meta-Llama-3-8B-Instruct",
                              model_kwargs={"torch_dtype": torch.bfloat16},
                              trust_remote_code=True,
                              device_map='auto')

    def _generate(self, messages, **sampling_kwargs):
        """Run one chat completion and return ``(model_prompt, reply)``.

        Args:
            messages: Chat messages (dicts with ``role`` and ``content``).
            **sampling_kwargs: Extra generation arguments such as
                ``do_sample``, ``temperature`` and ``top_p``.

        Returns:
            The templated prompt string that was sent to the pipeline and the
            generated text with that prompt prefix removed.
        """
        tokenizer = self.model.tokenizer
        model_prompt = tokenizer.apply_chat_template(messages,
                                                     tokenize=False,
                                                     add_generation_prompt=True)

        # Stop on either the tokenizer's EOS token or the "<|eot_id|>" token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = self.model(
            model_prompt,
            max_new_tokens=1024,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            **sampling_kwargs
        )
        # The pipeline output starts with the prompt; keep only the new text.
        return model_prompt, outputs[0]["generated_text"][len(model_prompt):]

    def prompt_verifier(self, prompt):
        """Send ``prompt`` with the default system prompt (greedy decoding).

        Returns:
            The generated reply text. The exchange is also written to the log.
        """
        messages = [
            {"role": "system", "content": self.init_prompt},
            {"role": "user", "content": prompt}]

        _, message = self._generate(messages, do_sample=False)

        print('prompted verifier')
        log('--- TO VERIFIER: ' + prompt)
        log(f'--- FROM VERIFIER: {message}')

        return message

    def prompt_verifier_state(self, prompt):
        """Send ``prompt`` followed by the full problem context (sampled decoding).

        The user message is ``prompt`` plus the problem statement, buggy code,
        bug description, bug fixes and correct code; only the persona is used
        as the system prompt.

        Returns:
            The generated reply text. The exchange is also written to the log.
        """
        messages = [
            {"role": "system", "content": verifier_persona},
            {"role": "user", "content": prompt + "\nInput:\n\n" + problem_statement + buggy_code + bug_description + bug_fixes + correct_code + "\nOutput:\n"}]

        _, message = self._generate(messages,
                                    do_sample=True,
                                    temperature=0.1,
                                    top_p=0.9)

        print('prompted verifier state')
        log('--- TO VERIFIER: ' + prompt)
        log(f'--- FROM VERIFIER: {message}')

        return message

    def get_state_repr(self):
        """Ask the model for the state attributes and return them as a list of strings."""
        prompt = v2v_get_state_repr + yaml_state_repr
        response = self.prompt_verifier_state(prompt)

        return parse_lines_yaml(response)

    def assess_misunderstanding(self, instructor_exp, student_exp):
        """Check whether the instructor's and student's explanations disagree.

        Returns:
            ``(explanation, is_discrepancy)`` as parsed by ``parse_yes_no``.
        """
        # check for discrepancy
        x = "is_discrepancy"
        prompt = i2i_discrepancy(instructor_exp, student_exp) + yaml_yes_no(x)
        response = self.prompt_verifier(prompt)

        return parse_yes_no(response, x)

    def assess_understanding_of_curr_level(self, instructor_question, current_response):
        """Judge whether the student's response answers the instructor's question.

        Returns:
            ``(explanation, is_correct)`` as parsed by ``parse_correct_answer``.
        """
        prompt = i2i_correct_answer_breakdown(instructor_question, current_response) + yaml_correct_answer
        response = self.prompt_verifier(prompt)
        return parse_correct_answer(response)

    def assess_state_level_understanding(self, instructor_question, current_response, target_representation, is_target):
        """Judge whether the latest exchange shows understanding of an attribute.

        The question and answer are taken from the last two entries of the
        module-level ``convo_history``; ``instructor_question`` and
        ``current_response`` are accepted but not read.

        Args:
            instructor_question: Unused.
            current_response: Unused.
            target_representation: The state attribute to assess.
            is_target: True when the question was generated for this attribute;
                otherwise the earlier conversation is included in the prompt.

        Returns:
            ``(explanation, answer)`` as parsed by ``parse_yes_no``.
        """
        x = "student_has_sufficient_target_understanding"
        q = convo_history[-2]
        a = convo_history[-1]
        c = "\n".join(convo_history[:-2])
        if not is_target: # next state attribute that question is NOT conditionally generated on
            prompt = i2i_address_target_convo_qa(q, a, c, target_representation)
        else:
            prompt = i2i_address_target_qa(q, a, target_representation)
        response = self.prompt_verifier(prompt + yaml_yes_no(x))
        return parse_yes_no(response, x)

    def update_code(self, student_bug_fixes):
        """Ask the model to apply the student's bug fixes to the buggy code.

        Returns:
            The text found between ``correct_code:`` and the next ``---`` in
            the model's reply.

        Raises:
            IndexError: If the reply does not contain that section.
        """
        instruction, final_prompt = i2i_apply_bug_fixes(buggy_code, student_bug_fixes)
        user_prompt = final_prompt + yaml_code_gen

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_prompt}]

        _, message = self._generate(messages, do_sample=False)

        print('prompted verifier')
        log('--- TO VERIFIER: ' + user_prompt)
        log(f'--- FROM VERIFIER: {message}')

        # Drop a trailing space after the key so the look-behind below matches.
        yaml = message.replace(": \n", ":\n")
        yaml = re.findall(r'(?<=correct_code:\n)((?:.|\n)*?)(?=---)', yaml, re.IGNORECASE)[0]

        log(f'--- PARSED CODE:\n{yaml}')
        return yaml

    def check_code(self, new_code):
        """Ask the model whether ``new_code`` is logically equivalent to the correct code.

        Returns:
            ``(explanation, answer)`` as parsed by ``parse_yes_no``.
        """
        x = "are_snippets_logically_equivalent"
        messages = [
            {"role": "system", "content": problem_statement},
            {"role": "user", "content": v2v_check_logical_eq(unit_tests, new_code, correct_code) + yaml_yes_no(x)}]

        model_prompt, message = self._generate(messages, do_sample=False)

        print('prompted verifier state')
        # Unlike the other methods, the full templated prompt is logged here.
        log('--- TO VERIFIER: ' + model_prompt)
        log(f'--- FROM VERIFIER: {message}')

        return parse_yes_no(message, x)
        

    
class Instructor():
    """Agent that generates the questions put to the student.

    Prompt pieces (``instructor_persona``, ``bug_fixes``, ``problem_statement``,
    ``bug_description``, ...) are read from module-level names at call time.
    """

    def __init__(self, model=None):
        """Store the system-prompt builder and the generation pipeline.

        Args:
            model: An existing text-generation pipeline to reuse. When None, a
                Llama-3-8B-Instruct pipeline is loaded.
        """
        self.init_prompt = lambda bc: instructor_persona + bug_fixes + problem_statement + bc + bug_description
        self.model = model if model is not None else pipeline(
            "text-generation",
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            trust_remote_code=True,
            device_map='auto',
        )

    def prompt_instructor(self, prompt):
        """Send ``prompt`` to the model (sampled decoding) and return the reply text.

        The system prompt is rebuilt on every call from the current value of
        the module-level ``buggy_code``. The exchange is written to the log.
        """
        tokenizer = self.model.tokenizer
        chat = [
            {"role": "system", "content": self.init_prompt(buggy_code)},
            {"role": "user", "content": prompt},
        ]
        model_prompt = tokenizer.apply_chat_template(chat,
                                                     tokenize=False,
                                                     add_generation_prompt=True)

        # Stop on either the tokenizer's EOS token or the "<|eot_id|>" token.
        stop_ids = [tokenizer.eos_token_id,
                    tokenizer.convert_tokens_to_ids("<|eot_id|>")]

        generated = self.model(
            model_prompt,
            max_new_tokens=1024,
            eos_token_id=stop_ids,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The pipeline output starts with the prompt; keep only the new text.
        message = generated[0]["generated_text"][len(model_prompt):]

        print('prompted instructor')
        log('--- TO INSTRUCTOR: ' + prompt)
        log(f'--- FROM INSTRUCTOR: {message}')
        return message

    def generate_candidate_questions(self, prev_qs=None, target=None, explanation="", tag="initial"):
        """Generate candidate questions for the student.

        Args:
            prev_qs: Previously asked questions, passed to the sibling/child
                prompt builders.
            target: The state attribute the questions should address.
            explanation: Verifier explanation passed to the sibling/child
                prompt builders.
            tag: ``"initial"`` for the first question set, ``"same"`` for more
                questions at the same level; any other value requests the
                next level of questions.

        Returns:
            The list of ``question_<n>`` values parsed from the model's reply.
        """
        history = "\n".join(convo_history)
        if tag == "initial":
            body = i2i_generate_initial_q(target)
        else:
            build_prompt = i2i_generate_sibling if tag == "same" else i2i_generate_child
            body = build_prompt(prev_qs, history, bug_fixes, bug_description, target, explanation)

        reply = self.prompt_instructor(body + yaml_cqg)
        return parse_lines_yaml(reply, tag='question')
    
class Student():
    """Simulated student backed by a Mistral-7B-Instruct text-generation pipeline.

    Prompt pieces (``problem_statement``, ``student_persona``, ``buggy_code``,
    ...) are read from module-level names at call time.
    """

    def __init__(self):
        """Store the prompt-prefix builder and load the generation pipeline."""
        self.init_prompt = lambda bc: problem_statement + bc + student_persona
        self.model = pipeline("text-generation",
                              model="mistralai/Mistral-7B-Instruct-v0.2",
                              trust_remote_code=True,
                              device_map='auto')

    def prompt_student(self, prompt, suffix="", do_sample=True):
        """Send ``prompt`` to the student model and return the full generated text.

        Args:
            prompt: Text placed inside the ``[INST] ... [/INST]`` section after
                the problem statement, current buggy code and persona.
            suffix: Text appended after ``[/INST]``.
            do_sample: Sample with ``top_k=10`` when true; otherwise decode
                without the ``top_k`` argument.

        Returns:
            The pipeline's ``generated_text`` (this still includes the prompt).
        """
        final_prompt = "<s>[INST]" + self.init_prompt(buggy_code) + prompt + "[/INST]" + suffix

        generation_args = {
            "do_sample": do_sample,
            "num_return_sequences": 1,
            "max_new_tokens": 200,
            "pad_token_id": self.model.tokenizer.eos_token_id,
        }
        if do_sample:
            generation_args["top_k"] = 10
        response = self.model(final_prompt, **generation_args)[0]

        print('prompted student')
        log('--- TO STUDENT: ' + prompt)
        log('--- FROM STUDENT: ' + response['generated_text'].split('[/INST]')[-1])
        return response['generated_text']

    def parse_student_exp(self, yaml):
        """Return the ``line_explanation`` value from the text after the last ``[/INST]``.

        Raises:
            IndexError: If no ``line_explanation: ...`` line followed by a
                newline is present.
        """
        flattened = yaml.replace(":\n", ": ").replace(": \n", ": ")
        reply = flattened.rpartition('[/INST]')[2]
        matches = re.findall(r'line_explanation: (.+)\n', reply, re.IGNORECASE)
        first_line = matches[0].split("\n", 1)[0]
        print(first_line)
        return first_line

    def parse_student_answer(self, yaml):
        """Return the first line of the ``student_answer`` value after the last ``[/INST]``.

        Raises:
            IndexError: If no ``student_answer:`` key is present.
        """
        flattened = yaml.replace(":\n", ": ").replace(": \n", ": ")
        reply = flattened.rpartition('[/INST]')[2]
        matches = re.findall(r'student_answer:\s*\n*([\S\s]*)', reply, re.IGNORECASE)
        first_line = matches[0].split("\n", 1)[0]
        print(first_line)
        return first_line

    def generate_bug_fixes(self):
        """Ask the student (without sampling) for bug fixes based on the conversation.

        Returns:
            The list of ``bug_fix_<c>`` values found in the model's reply.
        """
        prompt = i2s_generate_bug_fixes('\n'.join(convo_history))
        generated = self.prompt_student(prompt, do_sample=False)
        reply = generated.rpartition('[/INST]')[2]
        return re.findall(r'bug_fix_.:\s*(.*)', reply, re.IGNORECASE)

    def ask_student(self, question):
        """Put ``question`` to the student and return the parsed answer line."""
        framed_question = f"question: {question}\n" + yaml_student_answer
        response = self.prompt_student(framed_question, suffix="\nstudent_answer: ")
        return self.parse_student_answer(response)

def fix_misunderstanding(student: Student, instructor: Instructor, verifier: Verifier, state_representation, target_rep):
    """Question the student about one target attribute, then try to repair the code.

    The instructor generates questions for ``target_rep``; the student answers
    and the verifier judges each answer. Questioning stops once the verifier
    considers the last state attribute resolved, or once the student resolved
    at least one attribute and then got stuck on a later one. Afterwards the
    student proposes bug fixes, the verifier applies them to the buggy code
    and checks the result against the correct code.

    Args:
        student: Agent that answers questions and proposes bug fixes.
        instructor: Agent that generates candidate questions.
        verifier: Agent that judges answers and patches/checks the code.
        state_representation: Two-element list ``[attributes, resolved_flags]``;
            the flags list is updated in place.
        target_rep: The attribute to start from; must be an element of
            ``state_representation[0]``.

    Returns:
        ``(state_representation, new_code)``. ``new_code`` is the verifier's
        patched code, or the module-level ``buggy_code`` unchanged when the
        student proposed no bug fixes.

    Side effects:
        Appends to the module-level ``convo_history`` and to ``convo.txt`` and
        ``bug_fixes.txt`` in the run's output directory.
    """
    global num_questions
    global convo_history
    global buggy_code
    level = 0 # level 0 is asking about misunderstandings with code
    level_questions = {}     # level -> list of questions generated for that level
    level_indices = {}       # level -> index of the next question to ask at that level
    is_student_done = False
    level_explanations = {}  # level -> latest verifier explanation recorded at that level

    # error handling: seed level -1 so the "next level" lookup of level - 1 below
    # always finds an entry
    level_questions[-1] = [f'Can you walk through the logic of your code?']

    # text prepended to the next question; set after each verifier verdict
    prefix = ""

    # discrepancy, _ = verifier.assess_misunderstanding(instructor_exp, student_exp)
    candidate_questions = instructor.generate_candidate_questions(target=target_rep, tag="initial")

    while not is_student_done:
        # first visit to a level: adopt the most recently generated candidate questions
        if level not in level_questions.keys():
            level_questions[level] = candidate_questions
            level_indices[level] = 0 # setting i

        # get student answer for question i at level l
        # TODO: DEBUG THIS
        # NOTE(review): this lookup raises IndexError when the instructor returned
        # no (new) questions for this level - confirm whether that can happen.
        instructor_question = prefix + level_questions[level][level_indices[level]]
        convo_history.append("Instructor: " + instructor_question)
        student_question_response = student.ask_student(instructor_question)
        convo_history.append("Student: " + student_question_response)
        level_indices[level] += 1
        
        # use verifier to see if student understands the curr level questions
        clu_explanation, is_curr_level_understand = verifier.assess_understanding_of_curr_level(instructor_question, student_question_response)
        if is_curr_level_understand:
            # TODO: ADD CONDITION THAT TAKES IN TARGET ATTRIBUTE AND SEE IF IT IS RESOLVED --> STOP
            # state_representation layout: [[attribute_1, 2, ...], [False, False,..]]
            idx = state_representation[0].index(target_rep)
            # only the first attribute checked (the target itself) is the one the
            # question was generated for
            is_target = True
            for i in range(idx, len(state_representation[0])):
                # is this state attribute actually resolved?
                exp, flag = verifier.assess_state_level_understanding(instructor_question, student_question_response, state_representation[0][i], is_target)
                is_target = False
                # if it is resolved -> update state representation and move onto next attribute
                # if it is not resolved -> if no progression (i == idx), just prefix_next_level;
                #                       -> if progression (i > idx), ask student to generate bug fixes + ask instructor/verifier to update the code -> return state_repr + new code
                if flag:
                    state_representation[1][i] = True
                else:
                    # record the explanation under the level that was just asked,
                    # before the level counter is advanced
                    level_explanations[level] = exp
                    if i == idx:
                        level += 1
                        prefix = prefix_next_level
                    else:
                        is_student_done = True
                    break
            # the last attribute being resolved ends the questioning
            if state_representation[1][-1]:
                is_student_done = True
        else:
            # no change in level
            prefix = prefix_same_level
            level_explanations[level] = clu_explanation

        # append this exchange and the verifier's verdict to the transcript
        with open(f'instruct_not_assist/{FILE_NAME}/convo.txt', 'a+') as f:
            try:
                f.write(f'Instructor: {instructor_question}\n')
                f.write(f'Student: {student_question_response}\n')
                # f.write(f'Verifier: \n\t Previous Level: {is_prev_level_understand}, {plu_explanation}\n')
                f.write(f'\tCurrent Target: {target_rep}\n')
                f.write(f'\tCurrent Level: {is_curr_level_understand}, {clu_explanation}\n')
                f.write(f'\tCurrent State Representation: {state_representation[1]}\n\n')
            # NOTE(review): bare except hides any failure above, including
            # KeyboardInterrupt; consider narrowing it.
            except:
                f.write(f'\tCurrent Level: N/A, N/A\n')
        
        # generate new questions
        if not is_student_done:
            if not is_curr_level_understand: # same level
                candidate_questions = instructor.generate_candidate_questions(prev_qs='\n'.join(level_questions[level]), target=target_rep, explanation=level_explanations[level], tag="same")
                level_questions[level].extend(candidate_questions)
            else: # next level
                # level was advanced above, so level - 1 is the level that was just asked;
                # the new candidates are adopted at the top of the next iteration
                candidate_questions = instructor.generate_candidate_questions(prev_qs='\n'.join(level_questions[level - 1]), target=target_rep, explanation=level_explanations[level-1], tag="next")
                

    # generate new code
    student_bug_fixes = student.generate_bug_fixes()
    with open(f'instruct_not_assist/{FILE_NAME}/bug_fixes.txt', 'a+') as f:
        f.write(f'{student_bug_fixes}\n')
    if len(student_bug_fixes):
        new_code = verifier.update_code(student_bug_fixes)
        _, is_code_correct = verifier.check_code(new_code)

        # code judged equivalent to the correct code resolves every attribute
        if is_code_correct:
            state_representation[1] = [True]*len(state_representation[1])
    else:
        # no fixes proposed: hand back the current buggy code unchanged
        new_code = buggy_code
    return state_representation, new_code

def run():
    """Drive the tutoring loop until every state attribute is marked resolved.

    Builds the three agents (the instructor reuses the verifier's pipeline),
    asks the verifier for the state attributes, then repeatedly calls
    ``fix_misunderstanding`` on the first unresolved attribute. After each
    round the module-level ``buggy_code`` is replaced by the returned code and
    ``convo_history`` is reset. Once no attribute is unresolved, the current
    ``buggy_code`` is written to ``correct_code.txt``.
    """
    global num_questions
    global convo_history
    global buggy_code

    def describe(state):
        # One "<attribute>, <resolved>" line per state attribute.
        return "".join(f"{attribute}, {resolved}\n"
                       for attribute, resolved in zip(state[0], state[1]))

    verifier = Verifier()
    instructor = Instructor(verifier.model)
    student = Student()

    # get state representation based on student initial progress
    # starting point: buggy code; ending point: correct code
    # TODO: CHECK STATE REPRESENTATION FOR MULTIPLE BUGS AND IF WE NEED TO BUILD IT OUT BLOCK BY BLOCK
    state_attributes = verifier.get_state_repr()
    state_repr = [state_attributes, [False]*len(state_attributes)]
    log(f"State Representation: {describe(state_repr)}")

    while False in state_repr[1]:
        # Always work on the first attribute that is still unresolved.
        target = state_repr[0][state_repr[1].index(False)]
        state_repr, student_new_code = fix_misunderstanding(student, instructor, verifier, state_repr, target)
        buggy_code = student_new_code
        convo_history = []
        log(f"State Representation: {describe(state_repr)}")

    with open(f'instruct_not_assist/{FILE_NAME}/correct_code.txt', 'w+') as f:
        f.write(buggy_code)


if __name__ == "__main__":
    # Script entry point: parse arguments, prepare a clean output directory,
    # load the problem data into module-level names, then start the loop.
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, default='../data_pkls/15_44_sequential_search_conversational_thread_1.pkl')
    parser.add_argument('--bug_num', type=int, default=1)

    args = parser.parse_args()

    # The run name is the part of the path between "data_pkls/" and ".pkl".
    name_match = re.search(r'data_pkls/(.+)\.pkl', args.file, re.IGNORECASE)
    if name_match is None:
        parser.error(f"--file must look like '<...>data_pkls/<name>.pkl', got: {args.file}")
    FILE_NAME = name_match.group(1)

    # Create the output directory (and any missing parents), then remove the
    # files left over from a previous run, since they are opened in append mode.
    output_dir = os.path.join('instruct_not_assist', FILE_NAME)
    os.makedirs(output_dir, exist_ok=True)
    for stale_name in ('log.txt', 'bug_fixes.txt', 'convo.txt', 'correct_code.txt'):
        stale_path = os.path.join(output_dir, stale_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # NOTE(security): pickle.load can execute arbitrary code; only pass files
    # from a trusted source to --file.
    with open(args.file, 'rb') as data_file:
        extracted_data = pickle.load(data_file)
    problem_statement = extracted_data['problem']
    buggy_code = extracted_data['buggy_code']

    bug_fixes = extracted_data['bug_fixes']
    if args.bug_num == 1:
        # Keep only the first listed bug fix, preserving the "---" framing.
        first_bug_fix = re.findall(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', bug_fixes, re.IGNORECASE)[0].split("\n")[0]
        bug_fixes = re.sub(r'^---\nbug_fixes:\n[\S\s]*\n---\n$', f'---\nbug_fixes:\n{first_bug_fix}\n---\n', bug_fixes)
    else:
        # NOTE(review): here bug_fixes becomes a list of lines, while the
        # prompt builders in this file concatenate it with other prompt
        # pieces using "+" - confirm this branch works as intended.
        bug_fixes = re.findall(r'^---\nbug_fixes:\n([\S\s]*)\n---\n$', bug_fixes, re.IGNORECASE)[0].split("\n")


    bug_description = extracted_data['bug_desc'] # not a typo
    correct_code = extracted_data['correct_code']
    unit_tests = extracted_data['unit_tests']

    num_questions = 3 # k
    convo_history = []

    log(f"problem statement:\n{problem_statement}\nbuggy_code:\n{buggy_code}\ncorrect_code:\n{correct_code}\nbug_fixes:\n{bug_fixes}")

    run()

